#include <utils/err.h>
#include <seminix/cache.h>
#include <seminix/syscall.h>
#include <seminix/tcb.h>
#include <seminix/param.h>
#include <cap/cnode.h>
#include <cap/rlimit.h>
#include <cap/ipc_buffer.h>
#include <cap/endpoint.h>
#include <libseminix/types.h>

/* Slab cache for cap_cnode objects (one per cnode capability). */
static struct kmem_cache *cap_cnode_cachep;
/* Slab cache for cnode_t table descriptors. */
static struct kmem_cache *cnode_cachep;

/*
 * Create the two slab caches used by this file.
 *
 * Both are created with SLAB_PANIC, which is presumably why neither result
 * is checked here.
 */
static __init int cnode_cap_init(void)
{
    cap_cnode_cachep = KMEM_CACHE(cap_cnode, SLAB_PANIC);
    cnode_cachep     = KMEM_CACHE(cnode, SLAB_PANIC);

    return 0;
}
userver_initcall(cnode_cap_init)

/* Free a table created by alloc_cnodetable(): bitmap, slot array, descriptor. */
static void free_cnodetable(cnode_t *cnodes)
{
    kfree(cnodes->full_bits);
    kfree(cnodes->cnode_root);
    kmem_cache_free(cnode_cachep, cnodes);
}

/*
 * Size in bytes of the slot-allocation bitmap for a table of @nr slots.
 *
 * The bitmap is accessed with set_bit()/clear_bit()/find_next_zero_bit(),
 * which address whole unsigned longs, so the bit count is rounded UP to a
 * whole number of longs.  Truncating division would leave the highest slots
 * without a backing bit whenever @nr is not a multiple of BITS_PER_BYTE.
 * The size never drops below one cache line.
 *
 * Both alloc_cnodetable() and copy_cnodetable() use this helper, so every
 * table's bitmap is exactly this many bytes.
 */
static size_t cnode_bitmap_bytes(int nr)
{
    const size_t bits_per_long = BITS_PER_BYTE * sizeof(unsigned long);
    size_t longs = ((size_t)nr + bits_per_long - 1) / bits_per_long;
    size_t bytes = longs * sizeof(unsigned long);
    size_t line = (size_t)cache_line_size();

    return bytes > line ? bytes : line;
}

/*
 * Allocate an empty cnode table with @nr slots and a first-free search hint
 * of @next_cd.
 *
 * Returns the new table, or ERR_PTR(-SERRNO_ENOMEM) if any allocation
 * fails (never NULL).
 */
static cnode_t *alloc_cnodetable(int nr, int next_cd)
{
    cnode_t *cnodes;

    cnodes = kmem_cache_alloc(cnode_cachep, GFP_KERNEL | GFP_ZERO);
    if (!cnodes)
        goto out;

    cnodes->cnode_max = nr;
    cnodes->cnode_first = next_cd;
    cnodes->cnode_root = kcalloc(nr, sizeof (cap_t *), GFP_KERNEL);
    if (!cnodes->cnode_root)
        goto free_cnodes;

    cnodes->full_bits = kzalloc(cnode_bitmap_bytes(nr), GFP_KERNEL);
    if (!cnodes->full_bits)
        goto free_cnode_root;
    cnodes->cnode_used = 0;

    return cnodes;

free_cnode_root:
    kfree(cnodes->cnode_root);
free_cnodes:
    kmem_cache_free(cnode_cachep, cnodes);
out:
    return ERR_PTR(-SERRNO_ENOMEM);
}

/*
 * Copy the full state of @old_cnodes into the larger, freshly allocated
 * @new_cnodes: slot pointers, allocation bitmap, used-slot count and the
 * first-free search hint.
 *
 * The only caller (cnode_expand_table_lock()) invokes this with cnode_lock
 * held, so @old_cnodes is stable while it is copied.
 */
static void copy_cnodetable(cnode_t *new_cnodes, cnode_t *old_cnodes)
{
    assert(new_cnodes->cnode_max > old_cnodes->cnode_max);

    memcpy(new_cnodes->cnode_root, old_cnodes->cnode_root,
           old_cnodes->cnode_max * sizeof (cap_t *));
    memcpy(new_cnodes->full_bits, old_cnodes->full_bits,
           cnode_bitmap_bytes(old_cnodes->cnode_max));

    /*
     * Keep the bookkeeping consistent with the copied bitmap.  Without the
     * count, later cnode_cap_free_slot() calls would drive cnode_used
     * negative.
     */
    new_cnodes->cnode_used = old_cnodes->cnode_used;
    new_cnodes->cnode_first = old_cnodes->cnode_first;
}

/*
 * Return the index of the first clear bit at or after @start in the slot
 * bitmap of @cnodes, clamped so the result never exceeds cnode_max.
 */
static int find_next_cd(cnode_t *cnodes, int start)
{
    int limit = cnodes->cnode_max;
    int bit = find_next_zero_bit(cnodes->full_bits, limit, start);

    return (bit > limit) ? limit : bit;
}

/* Return 0 when slot @nr lies inside @cnodes, -SERRNO_EOVERFLOW otherwise. */
static int expand_cnodes_check(cnode_t *cnodes, int nr)
{
    return (nr >= cnodes->cnode_max) ? -SERRNO_EOVERFLOW : 0;
}

/*
 * Reserve the lowest free slot at or after the larger of @start and
 * cnode_first.
 *
 * The caller holds cnode_lock.  On success the slot's bit is set, the used
 * count is bumped, the search hint moves past the slot, and the slot index
 * is returned.  If no slot is free from that point on, -SERRNO_EOVERFLOW is
 * returned.
 */
static int __get_unused_cd_locked(cnode_t *cnodes, uint32_t start)
{
    int slot = ((int)start < cnodes->cnode_first) ? cnodes->cnode_first
                                                  : (int)start;
    int err;

    if (slot < cnodes->cnode_max)
        slot = find_next_cd(cnodes, slot);

    err = expand_cnodes_check(cnodes, slot);
    if (err)
        return err;

    set_bit(slot, cnodes->full_bits);
    cnodes->cnode_used++;
    cnodes->cnode_first = slot + 1;

    return slot;
}

/*
 * Reserve a free slot in @cap_cnode's current table, starting the search at
 * the table's cnode_first hint.
 *
 * Returns the reserved index, or a negative error code if the table is full.
 */
int cnode_get_unused_slot(cap_cnode_t *cap_cnode)
{
    cnode_t *table;
    int slot;

    spin_lock(cap_cnode->cnode_lock);
    table = cap_cnode->cnode[0];
    slot = __get_unused_cd_locked(table, table->cnode_first);
    spin_unlock(cap_cnode->cnode_lock);

    return slot;
}

/*
 * Release slot @index of @cap_cnode's current table.
 *
 * Clears the slot pointer and its bitmap bit, lowers the first-free hint if
 * @index is below it, and decrements the used count.  BUG()s if @index is
 * past the end of the table.
 */
void cnode_cap_free_slot(cap_cnode_t *cap_cnode, int index)
{
    cnode_t *table;

    spin_lock(cap_cnode->cnode_lock);
    table = cap_cnode->cnode[0];
    BUG_ON(index >= table->cnode_max);

    table->cnode_root[index] = NULL;
    clear_bit(index, table->full_bits);
    if (table->cnode_first > index)
        table->cnode_first = index;
    table->cnode_used--;

    spin_unlock(cap_cnode->cnode_lock);
}

/*
 * Publish @cap in slot @index of @cap_cnode's current table.
 *
 * The slot must already be reserved (its bitmap bit set, e.g. by
 * cnode_get_unused_slot()) and still empty.
 *
 * @cap's back-pointer and index are filled in BEFORE the slot is written,
 * and both happen under cnode_lock.  Any lookup that finds @cap through this
 * slot therefore also sees cap->cap_cnode pointing at this cnode;
 * cnode_capget() asserts exactly that.  Setting them after unlocking, as
 * before, left a window where a concurrent lookup saw a stale back-pointer.
 */
void cnode_cap_insert_slot(cap_cnode_t *cap_cnode, int index, cap_t *cap)
{
    cnode_t *cnodes;

    spin_lock(cap_cnode->cnode_lock);
    cnodes = cap_cnode->cnode[0];
    assert(!cnodes->cnode_root[index]);
    assert(test_bit(index, cnodes->full_bits));
    cap->cap_cnode = cap_cnode;
    cap->index = index;
    cnodes->cnode_root[index] = cap;
    spin_unlock(cap_cnode->cnode_lock);
}

/* Acquire the cnode_lock of cnode capability @cap. */
static __always_inline void cnode_cap_lock(cap_t *cap)
{
    cap_cnode_t *cc;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);
    cc = CAP_CNODE_PTR(cap);
    spin_lock(cc->cnode_lock);
}

/* Release the cnode_lock of cnode capability @cap. */
static __always_inline void cnode_cap_unlock(cap_t *cap)
{
    cap_cnode_t *cc;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);
    cc = CAP_CNODE_PTR(cap);
    spin_unlock(cc->cnode_lock);
}

/* Number of slots in the current table of cnode capability @cap. */
static __always_inline int cnode_cap_max(cap_t *cap)
{
    cnode_t *table;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);
    table = CAP_CNODE_PTR(cap)->cnode[0];
    return table->cnode_max;
}

/*
 * Return the cap stored in slot @index of cnode capability @cap.
 *
 * The caller holds @cap's cnode_lock (see cnode_capget()).
 *
 * Returns ERR_PTR(-SERRNO_EOVERFLOW) if @index is outside [0, cnode_max),
 * or ERR_PTR(-SERRNO_EILLEGAL) if the slot is empty.
 */
static __always_inline cap_t *cnode_cap_slot(cap_t *cap, int index)
{
    cap_t *slot_cap;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);

    /*
     * @index is derived from a descriptor passed in from user space, so
     * reject negative values as well as values past the end.  A negative
     * index would read outside cnode_root.
     */
    if (index < 0 || index >= cnode_cap_max(cap))
        return ERR_PTR(-SERRNO_EOVERFLOW);

    slot_cap = CAP_CNODE_PTR(cap)->cnode[0]->cnode_root[index];
    if (!slot_cap)
        return ERR_PTR(-SERRNO_EILLEGAL);
    return slot_cap;
}

/*
 * Resolve capability descriptor @cd against the current task's cnode and
 * take a reference on the resulting cap.
 *
 * If capdesc_has_cnode(@cd) is true, slot capdesc_get_cnode(@cd) of the
 * current cnode must hold a cnode cap.  The target is then looked up in
 * that nested cnode, and cap_set_from_cnode() is applied to it;
 * cnode_capput() uses that flag to drop the extra reference on the
 * enclosing cnode.
 *
 * @cap_type: required capability type, or CAPTYPE_NOTVERIFY to accept any
 *            type that is not invalid.
 *
 * On success: returns the target cap with references held on it and on
 * every cnode walked.  Release them with cnode_capput().
 *
 * On failure: returns an ERR_PTR (-SERRNO_EILLEGAL, -SERRNO_EOVERFLOW or
 * -SERRNO_EREMOVING) and holds no references.
 */
cap_t *cnode_capget(int cd, int cap_type)
{
    int cnode_index, target_index, has_cnode;
    cap_t *src_cnode_cap, *cnode_cap, *target_cap, *tmp_cap = NULL;
    cap_t *err;

    if (!current->cap_cnode)
        return ERR_PTR(-SERRNO_EILLEGAL);

    src_cnode_cap = &current->cap_cnode->cap;
    if (!capget_not_zero(src_cnode_cap)) {
        assert(cap_removing(src_cnode_cap));
        return ERR_PTR(-SERRNO_EREMOVING);
    }
    cnode_cap_lock(src_cnode_cap);

    has_cnode = capdesc_has_cnode(cd);
    if (has_cnode) {
        cnode_index = capdesc_get_cnode(cd);
        cnode_cap = cnode_cap_slot(src_cnode_cap, cnode_index);
        if (IS_ERR(cnode_cap)) {
            err = cnode_cap;
            goto out;
        }

        if (cap_get_cap_type(cnode_cap) != cap_cnode_cap) {
            err = ERR_PTR(-SERRNO_EILLEGAL);
            goto out;
        }
        if (!capget_not_zero(cnode_cap)) {
            assert(cap_removing(cnode_cap));
            /*
             * Must report an error here.  Returning cnode_cap itself would
             * hand the caller an unreferenced cnode cap as if it were the
             * target.
             */
            err = ERR_PTR(-SERRNO_EREMOVING);
            goto out;
        }
        /*
         * Switch to the nested cnode.  Its reference is now held, so the
         * outer lock can be dropped; the outer reference stays in tmp_cap.
         */
        cnode_cap_unlock(src_cnode_cap);
        tmp_cap = src_cnode_cap;
        src_cnode_cap = cnode_cap;
        cnode_cap_lock(src_cnode_cap);
    }

    target_index = capdesc_get_index(cd);
    target_cap = cnode_cap_slot(src_cnode_cap, target_index);
    if (IS_ERR(target_cap)) {
        err = target_cap;
        goto out;
    }
    if (cap_type == CAPTYPE_NOTVERIFY) {
        if (seminix_cap_type_invalid(cap_get_cap_type(target_cap))) {
            err = ERR_PTR(-SERRNO_EILLEGAL);
            goto out;
        }
    } else if (cap_get_cap_type(target_cap) != cap_type) {
        err = ERR_PTR(-SERRNO_EILLEGAL);
        goto out;
    }

    assert(&target_cap->cap_cnode->cap == src_cnode_cap);
    if (!capget_not_zero(target_cap)) {
        assert(cap_removing(target_cap));
        err = ERR_PTR(-SERRNO_EREMOVING);
        goto out;
    }
    if (has_cnode)
        cap_set_from_cnode(target_cap);
    cnode_cap_unlock(src_cnode_cap);

    return target_cap;

out:
    /* Undo every reference taken so far; the caller gets only the error. */
    cnode_cap_unlock(src_cnode_cap);
    capput(src_cnode_cap);
    if (tmp_cap)
        capput(tmp_cap);
    return err;
}

/*
 * Drop the references taken by a successful cnode_capget() on @cap:
 *   - the cap itself,
 *   - the cnode holding it, and
 *   - if cap_from_cnode() is set, the cnode one level further up (the flag
 *     is cleared).
 *
 * Every pointer needed is read, and the flag cleared, before the first
 * capput().  capput() may release the last reference to an object, so
 * touching @cap or its cnode afterwards risks a use-after-free.
 */
void cnode_capput(cap_t *cap)
{
    cap_cnode_t *holder = cap->cap_cnode;
    cap_t *outer = NULL;

    if (cap_from_cnode(cap)) {
        cap_clear_from_cnode(cap);
        outer = &holder->cap.cap_cnode->cap;
    }

    capput(cap);
    capput(&holder->cap);
    if (outer)
        capput(outer);
}

/*
 * Look up a cnode on behalf of the current task and take references on it.
 *
 * @cnode: THIS_CNODE_CAPDESC for the task's own cnode; otherwise an index
 *         into the task's cnode whose slot must hold a cnode cap.
 *
 * A reference on the task's own cnode is always taken first.  In the nested
 * case a reference on the nested cnode is also taken, and both are kept on
 * success.
 * NOTE(review): confirm cnode_put() releases the reference on the task's
 * own cnode in that case.
 *
 * Returns the cnode, or an ERR_PTR (never NULL) with no references held.
 */
cap_cnode_t *cnode_get(int cnode)
{
    int ret;
    cap_t *cap;
    cap_cnode_t *cap_cnode;

    cap_cnode = current->cap_cnode;
    if (!cap_cnode)
        return ERR_PTR(-SERRNO_EILLEGAL);

    if (!capget_not_zero(CAP_REF(cap_cnode))) {
        assert(cap_removing(CAP_REF(cap_cnode)));
        return ERR_PTR(-SERRNO_EREMOVING);
    }

    if (cnode == THIS_CNODE_CAPDESC)
        return cap_cnode;

    cnode_cap_lock(CAP_REF(cap_cnode));
    /*
     * @cnode comes straight from a syscall argument: check both bounds
     * before indexing cnode_root.
     */
    if (cnode < 0 || cnode >= cap_cnode->cnode[0]->cnode_max) {
        ret = -SERRNO_EOVERFLOW;
        goto out;
    }
    cap = cap_cnode->cnode[0]->cnode_root[cnode];
    if (!cap) {
        ret = -SERRNO_EILLEGAL;
        goto out;
    }
    if (cap_get_cap_type(cap) != cap_cnode_cap) {
        ret = -SERRNO_EILLEGAL;
        goto out;
    }
    if (!capget_not_zero(cap)) {
        assert(cap_removing(cap));
        ret = -SERRNO_EREMOVING;
        goto out;
    }
    cnode_cap_unlock(CAP_REF(cap_cnode));
    return CAP_CNODE_PTR(cap);
out:
    cnode_cap_unlock(CAP_REF(cap_cnode));
    /*
     * The caller receives only an error pointer and has nothing to put, so
     * drop the reference taken on the task's cnode above.
     */
    capput(CAP_REF(cap_cnode));
    return ERR_PTR(ret);
}

/*
 * Allocate a larger slot table for @cap_cnode and copy the current one
 * into it.
 *
 * The size doubles, clamped to CAPDESC_CNODE_MAX - 10.
 *
 * Returns:
 *   - ERR_PTR(-SERRNO_EOVERFLOW) if the table cannot grow any further;
 *   - the allocation error from alloc_cnodetable() on failure;
 *   - on success, the new table with cap_cnode->cnode_lock HELD.  The old
 *     table is still installed in cap_cnode->cnode[0]; the caller must swap
 *     it in, drop the lock and free the old table.
 *
 * On error the lock is not held.
 */
static cnode_t *cnode_expand_table_lock(cap_cnode_t *cap_cnode)
{
    cnode_t *new_cnodes, *cnodes;
    int old_max, first, nr;

again:
    /*
     * Snapshot the current size under the lock: a concurrent expand may
     * swap and free the table at any time.
     */
    spin_lock(cap_cnode->cnode_lock);
    old_max = cap_cnode->cnode[0]->cnode_max;
    first = cap_cnode->cnode[0]->cnode_first;
    spin_unlock(cap_cnode->cnode_lock);

    nr = old_max * 2;
    if (nr > (int)CAPDESC_CNODE_MAX) {
        /*
         * Use >=: at exactly the clamp value there is no room left.  The
         * new table would be the same size and trip copy_cnodetable()'s
         * assert.
         */
        if (old_max >= ((int)CAPDESC_CNODE_MAX - 10))
            return ERR_PTR(-SERRNO_EOVERFLOW);
        nr = CAPDESC_CNODE_MAX - 10;
    }

    /* alloc_cnodetable() reports failure with ERR_PTR, never NULL. */
    new_cnodes = alloc_cnodetable(nr, first);
    if (IS_ERR(new_cnodes))
        return new_cnodes;

    spin_lock(cap_cnode->cnode_lock);
    cnodes = cap_cnode->cnode[0];
    if (cnodes->cnode_max != old_max) {
        /*
         * Another expand replaced the table while we were allocating.
         * Tables only ever grow, so a size change is a reliable signal.
         * Start over from the new table.
         */
        spin_unlock(cap_cnode->cnode_lock);
        free_cnodetable(new_cnodes);
        goto again;
    }
    copy_cnodetable(new_cnodes, cnodes);

    return new_cnodes;
}

/*
 * cnode_expand(cnode): grow the slot table of the cnode cap named by
 * descriptor @cnode.
 *
 * The cap must be writable.  Returns 0 on success or a negative error code.
 */
SYSCALL_DEFINE1(cnode_expand, int, cnode)
{
    cap_t *cap;
    cap_cnode_t *target;
    cnode_t *grown, *old;
    int ret;

    cap = cnode_capget(cnode, cap_cnode_cap);
    if (IS_ERR(cap))
        return PTR_ERR(cap);

    ret = 0;
    if (cap_can_write(cap)) {
        target = CAP_CNODE_PTR(cap);
        grown = cnode_expand_table_lock(target);
        if (IS_ERR(grown)) {
            ret = PTR_ERR(grown);
        } else {
            /*
             * cnode_lock is held here: install the new table, then free
             * the old one after unlocking.
             */
            old = target->cnode[0];
            target->cnode[0] = grown;
            spin_unlock(target->cnode_lock);
            free_cnodetable(old);
        }
    } else {
        ret = -SERRNO_EILLEGAL;
    }

    cnode_capput(cap);
    return ret;
}

/*
 * Duplicate @parent with @rights, charging the copy against @cap_rlimit.
 *
 * Returns ERR_PTR(-SERRNO_EOVERFLOW) if the limit for @parent's cap type
 * would be exceeded, cap_dup()'s error pointer if the copy fails, or the new
 * cap with its rlimit attached and charged.
 */
static cap_t *do_dup(cap_rlimit_t *cap_rlimit, cap_t *parent, int rights)
{
    struct caprlimit *crlimit;
    unsigned long limit;
    cap_t *dup;

    limit = cap_dup_rlimit(parent);
    crlimit = &cap_rlimit->crlim[cap_get_cap_type(parent)];
    if (rlimit_overflow(limit, crlimit))
        return ERR_PTR(-SERRNO_EOVERFLOW);

    dup = cap_dup(parent, rights);
    if (!IS_ERR(dup)) {
        cap_set_rlimit(dup, cap_rlimit);
        rlimit_up(dup, limit);
    }
    return dup;
}

/*
 * cnode_dup(cnode, cd, rights): duplicate the cap named by @cd (which must
 * live directly in the current cnode) with @rights, and store the copy in a
 * free slot of cnode @cnode.
 *
 * Returns the result of capdesc_set_desc() for the new slot, or a negative
 * error code.
 */
SYSCALL_DEFINE3(cnode_dup, int, cnode, int, cd, int, rights)
{
    int ret, index;
    cap_t *cap, *new_cap;
    cap_rlimit_t *cap_rlimit;
    cap_cnode_t *cap_cnode;

    cap_rlimit = rlimit_get();
    if (!cap_rlimit)
        return -SERRNO_EILLEGAL;

    /*
     * cnode_get() reports failure with ERR_PTR, never NULL.  A NULL check
     * would let an error pointer through to the slot code below.
     */
    cap_cnode = cnode_get(cnode);
    if (IS_ERR(cap_cnode)) {
        ret = PTR_ERR(cap_cnode);
        goto put_rlimit;
    }

    /* Only caps in the current cnode itself may be duplicated. */
    if (capdesc_has_cnode(cd)) {
        ret = -SERRNO_EILLEGAL;
        goto put_cnode;
    }
    cap = cnode_capget(cd, CAPTYPE_NOTVERIFY);
    if (IS_ERR(cap)) {
        ret = PTR_ERR(cap);
        goto put_cnode;
    }

    index = cnode_get_unused_slot(cap_cnode);
    if (index < 0) {
        ret = index;
        goto put_cnode_cap;
    }

    new_cap = do_dup(cap_rlimit, cap, rights);
    if (IS_ERR(new_cap)) {
        ret = PTR_ERR(new_cap);
        cnode_cap_free_slot(cap_cnode, index);
        goto put_cnode_cap;
    }

    cnode_cap_insert_slot(cap_cnode, index, new_cap);
    ret = capdesc_set_desc(cnode, index);
put_cnode_cap:
    cnode_capput(cap);
put_cnode:
    cnode_put(cap_cnode);
put_rlimit:
    rlimit_put(cap_rlimit);
    return ret;
}

/*
 * cnode_move(cnode, cd): move the cap named by @cd out of its current slot
 * and into a free slot of cnode @cnode.
 *
 * Returns the result of capdesc_set_desc() for the new slot, or a negative
 * error code.
 */
SYSCALL_DEFINE2(cnode_move, int, cnode, int, cd)
{
    int ret, index;
    cap_t *cap;
    cap_cnode_t *cap_cnode;

    /*
     * cnode_get() reports failure with ERR_PTR, never NULL.  A NULL check
     * would let an error pointer through to the slot code below.
     */
    cap_cnode = cnode_get(cnode);
    if (IS_ERR(cap_cnode))
        return PTR_ERR(cap_cnode);

    cap = cnode_capget(cd, CAPTYPE_NOTVERIFY);
    if (IS_ERR(cap)) {
        ret = PTR_ERR(cap);
        goto put_cnode;
    }

    index = cnode_get_unused_slot(cap_cnode);
    if (index < 0) {
        ret = index;
        goto put_cnode_cap;
    }

    cnode_cap_free_slot_cap(cap);
    cnode_cap_insert_slot(cap_cnode, index, cap);
    ret = capdesc_set_desc(cnode, index);
put_cnode_cap:
    /*
     * NOTE(review): after a successful move, cnode_cap_insert_slot() has
     * pointed cap->cap_cnode at the destination cnode.  cnode_capput()
     * therefore drops references on the destination (and its parent, if
     * cap_from_cnode() is set), not on the source cnode(s) that
     * cnode_capget() referenced.  Confirm cnode_cap_free_slot_cap()
     * rebalances this.
     */
    cnode_capput(cap);
put_cnode:
    cnode_put(cap_cnode);
    return ret;
}

/*
 * cnode_revoke(parent, cd): revoke the cap named by @cd, which must be a
 * child of the cap named by @parent.
 *
 * @parent must live directly in the current cnode (no nested-cnode
 * descriptor).  Returns cap_revoke()'s result or a negative error code.
 */
SYSCALL_DEFINE2(cnode_revoke, int, parent, int, cd)
{
    cap_t *parent_cap, *child;
    int err;

    if (capdesc_has_cnode(parent))
        return -SERRNO_EILLEGAL;

    parent_cap = cnode_capget(parent, CAPTYPE_NOTVERIFY);
    if (IS_ERR(parent_cap))
        return PTR_ERR(parent_cap);

    child = cnode_capget(cd, CAPTYPE_NOTVERIFY);
    if (IS_ERR(child)) {
        err = PTR_ERR(child);
    } else {
        if (cap_is_parent(parent_cap, child))
            err = cap_revoke(parent_cap, child);
        else
            err = -SERRNO_EILLEGAL;
        cnode_capput(child);
    }

    cnode_capput(parent_cap);
    return err;
}

/*
 * cnode_delete(parent, cd): delete the cap named by @cd, which must be a
 * child of the cap named by @parent.
 *
 * @parent must live directly in the current cnode (no nested-cnode
 * descriptor).  Returns cap_delete()'s result or a negative error code.
 */
SYSCALL_DEFINE2(cnode_delete, int, parent, int, cd)
{
    cap_t *parent_cap, *child;
    int err;

    if (capdesc_has_cnode(parent))
        return -SERRNO_EILLEGAL;

    parent_cap = cnode_capget(parent, CAPTYPE_NOTVERIFY);
    if (IS_ERR(parent_cap))
        return PTR_ERR(parent_cap);

    child = cnode_capget(cd, CAPTYPE_NOTVERIFY);
    if (IS_ERR(child)) {
        err = PTR_ERR(child);
    } else {
        if (cap_is_parent(parent_cap, child))
            err = cap_delete(parent_cap, child);
        else
            err = -SERRNO_EILLEGAL;
        cnode_capput(child);
    }

    cnode_capput(parent_cap);
    return err;
}

/*
 * cap_ops->cap_create for cnode caps.
 *
 * Builds a cap_cnode with:
 *   - a freshly allocated lock,
 *   - a one-element table-pointer array, and
 *   - an initial table of object->cnode.count slots.
 *
 * Returns the embedded cap, or an ERR_PTR after freeing whatever was
 * already allocated.
 */
static cap_t *cnode_cap_create(seminix_object_t *object)
{
    struct cap_cnode *cc;
    spinlock_t *lock;
    cnode_t *table;
    int err = -SERRNO_ENOMEM;

    cc = kmem_cache_alloc(cap_cnode_cachep, GFP_KERNEL | GFP_ZERO);
    if (!cc)
        return ERR_PTR(-SERRNO_ENOMEM);

    lock = kmalloc(sizeof (*lock), GFP_KERNEL);
    if (!lock)
        goto err_free_cc;
    spin_lock_init(lock);
    cc->cnode_lock = lock;

    cc->cnode = kzalloc(sizeof (cnode_t *), GFP_KERNEL);
    if (!cc->cnode)
        goto err_free_lock;

    table = alloc_cnodetable(object->cnode.count, 0);
    if (IS_ERR(table)) {
        err = PTR_ERR(table);
        goto err_free_array;
    }
    cc->cnode[0] = table;

    return &cc->cap;

err_free_array:
    kfree(cc->cnode);
err_free_lock:
    kfree(lock);
err_free_cc:
    kmem_cache_free(cap_cnode_cachep, cc);
    return ERR_PTR(err);
}

/*
 * cap_ops->cap_delete for cnode caps.
 *
 * Frees everything cnode_cap_create() allocated: the current table, the
 * table-pointer array, the lock and the cap_cnode object itself.
 */
static void cnode_cap_delete(cap_t *cap)
{
    cap_cnode_t *cc;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);
    cc = CAP_CNODE_PTR(cap);

    free_cnodetable(cc->cnode[0]);
    kfree(cc->cnode);
    kfree(cc->cnode_lock);
    kmem_cache_free(cap_cnode_cachep, cc);
}

/*
 * cap_ops->cap_dup for cnode caps.
 *
 * The new cap_cnode shares the lock and the table-pointer array with
 * @cap's.  A table swapped in through either one (see cnode_expand) is
 * therefore seen through both.
 *
 * Returns the new cap or ERR_PTR(-SERRNO_ENOMEM).
 */
static cap_t *cnode_cap_dup(cap_t *cap)
{
    struct cap_cnode *src, *dup;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);
    src = CAP_CNODE_PTR(cap);

    dup = kmem_cache_alloc(cap_cnode_cachep, GFP_KERNEL | GFP_ZERO);
    if (dup == NULL)
        return ERR_PTR(-SERRNO_ENOMEM);

    /*
     * dup->cnode aliases src->cnode, so dup->cnode[0] already equals
     * src->cnode[0].  No separate table assignment is needed.
     */
    spin_lock(src->cnode_lock);
    dup->cnode_lock = src->cnode_lock;
    dup->cnode = src->cnode;
    spin_unlock(src->cnode_lock);

    return CAP_REF(dup);
}

/*
 * cap_ops->cap_revoke for cnode caps.
 *
 * Frees only the cap_cnode object.  The lock, table-pointer array and table
 * (which cnode_cap_dup() shares between copies) are left untouched.
 */
static void cnode_cap_revoke(cap_t *cap)
{
    cap_cnode_t *cap_cnode;

    assert(cap_get_cap_type(cap) == cap_cnode_cap);

    /*
     * Convert with CAP_CNODE_PTR(), like every other op in this file.  A raw
     * (cap_cnode_t *) cast is only correct if the cap happens to be the
     * first member.
     */
    cap_cnode = CAP_CNODE_PTR(cap);
    kmem_cache_free(cap_cnode_cachep, cap_cnode);
}

/* struct cap_ops callbacks implementing the cnode capability type. */
const struct cap_ops cnode_cap_ops __ro_after_init = {
    .cap_create = cnode_cap_create,
    .cap_dup    = cnode_cap_dup,
    .cap_revoke = cnode_cap_revoke,
    .cap_delete = cnode_cap_delete,
};
